import numpy as np
import cv2
import argparse
import os
import sys

path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(path, "common/"))
sys.path.append(os.path.join(path, "common/acllite"))

from constants import ACL_MEM_MALLOC_HUGE_FIRST, ACL_MEMCPY_DEVICE_TO_DEVICE, IMG_EXT
from acllite_model import AclLiteModel
from acllite_image import AclLiteImage
from acllite_resource import AclLiteResource


def letterbox(img, new_shape=(416, 416), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True):
    """Resize *img* to fit inside *new_shape* keeping aspect ratio, padding the rest with *color*.

    Returns (padded_image, (ratio_w, ratio_h), (pad_w, pad_h)).
    """
    orig_h, orig_w = img.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    target_h, target_w = new_shape[0], new_shape[1]

    scale = min(target_h / orig_h, target_w / orig_w)
    if not scaleup:
        # Only shrink — never upscale (keeps small images sharp).
        scale = min(scale, 1.0)

    ratio = scale, scale  # width, height ratios
    unpad_w = int(round(orig_w * scale))
    unpad_h = int(round(orig_h * scale))
    pad_w = target_w - unpad_w
    pad_h = target_h - unpad_h
    if auto:
        # Minimum rectangle: pad only up to the nearest multiple of 64.
        pad_w, pad_h = np.mod(pad_w, 64), np.mod(pad_h, 64)
    elif scaleFill:
        # Stretch to exactly target size: no padding, independent w/h ratios.
        pad_w, pad_h = 0.0, 0.0
        unpad_w, unpad_h = target_w, target_h
        ratio = target_w / orig_w, target_h / orig_h

    # Split the padding evenly between the two sides.
    pad_w /= 2
    pad_h /= 2
    if (orig_w, orig_h) != (unpad_w, unpad_h):
        img = cv2.resize(img, (unpad_w, unpad_h), interpolation=cv2.INTER_LINEAR)
    # The +/-0.1 rounding puts any odd leftover pixel on the bottom/right side.
    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return img, ratio, (pad_w, pad_h)


def clip_coords(boxes, img_shape):
    """Clip xyxy bounding boxes in place to the image bounds.

    Args:
        boxes: (n, 4+) ndarray of [x1, y1, x2, y2, ...] coordinates; modified in place.
        img_shape: (height, width) of the target image.

    Bug fix: ``ndarray.clip(min, max)`` without ``out=`` returns a NEW array
    and leaves *boxes* untouched, so the original function had no effect.
    Writing through ``out=`` on the column views clips in place as intended.
    """
    boxes[:, 0].clip(0, img_shape[1], out=boxes[:, 0])  # x1
    boxes[:, 1].clip(0, img_shape[0], out=boxes[:, 1])  # y1
    boxes[:, 2].clip(0, img_shape[1], out=boxes[:, 2])  # x2
    boxes[:, 3].clip(0, img_shape[0], out=boxes[:, 3])  # y2


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
    """Rescale xyxy *coords* from the letterboxed shape *img1_shape* back to *img0_shape*.

    If *ratio_pad* is None, the letterbox gain and padding are recomputed
    from the two shapes; otherwise it is ((gain, gain), (pad_w, pad_h)).
    Mutates and returns *coords*.
    """
    if ratio_pad is None:
        # Reconstruct the letterbox transform: gain = old / new.
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])
        pad_x = (img1_shape[1] - img0_shape[1] * gain) / 2
        pad_y = (img1_shape[0] - img0_shape[0] * gain) / 2
    else:
        gain = ratio_pad[0][0]
        pad_x, pad_y = ratio_pad[1][0], ratio_pad[1][1]

    # Undo the padding, then undo the scaling.
    coords[:, [0, 2]] -= pad_x
    coords[:, [1, 3]] -= pad_y
    coords[:, :4] /= gain
    clip_coords(coords, img0_shape)
    return coords


class Detector():
    """YOLOv5 detector running an offline (.om) model on an Ascend NPU via AclLite.

    Wraps model loading, letterbox preprocessing, inference, confidence
    filtering, and OpenCV NMS post-processing for a single image.
    """

    def __init__(self, opt):
        super(Detector, self).__init__()
        self.img_size = opt.img_size    # square inference resolution (pixels)
        self.threshold = opt.conf_thres  # objectness/confidence cutoff
        self.iou_thres = opt.iou_thres   # IoU threshold passed to NMS
        self.stride = 1
        self.weights = opt.weights       # path to the .om model file
        self.init_model()
        # COCO class names, indexed by the model's class id.
        self.names = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
                      "traffic light",
                      "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
                      "cow",
                      "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
                      "frisbee",
                      "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
                      "surfboard",
                      "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana",
                      "apple",
                      "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
                      "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard",
                      "cell phone",
                      "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
                      "teddy bear",
                      "hair drier", "toothbrush"]

        # Labels to report; currently identical to self.names (not consulted in detect()).
        self.detected_labels = ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
                                "traffic light",
                                "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse",
                                "sheep", "cow",
                                "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
                                "suitcase", "frisbee",
                                "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
                                "skateboard", "surfboard",
                                "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
                                "banana", "apple",
                                "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
                                "chair", "couch",
                                "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
                                "keyboard", "cell phone",
                                "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
                                "scissors", "teddy bear",
                                "hair drier", "toothbrush"]

    def init_model(self):
        """Load the .om model with AclLite; the commented-out code is the previous onnxruntime path."""
        # sess = onnxruntime.InferenceSession(self.weights,
        #                                     providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider',
        #                                                'CPUExecutionProvider'])
        # self.input_name = sess.get_inputs()[0].name
        network = AclLiteModel(self.weights)

        #         output_names = []
        #         for i in range(len(sess.get_outputs())):
        #             print('output shape:', sess.get_outputs()[i].name)
        #             output_names.append(sess.get_outputs()[i].name)

        #         self.output_name = sess.get_outputs()[0].name
        #         print('input name:%s, output name:%s' % (self.input_name, self.output_name))
        #         input_shape = sess.get_inputs()[0].shape
        #         print('input_shape:', input_shape)
        self.m = network

    def preprocess(self, img):
        """Letterbox-resize *img* and convert it to a normalized NCHW float32 batch of 1.

        Returns (original_image_copy, preprocessed_batch).
        """
        img0 = img.copy()
        img = letterbox(img, new_shape=self.img_size)[0]
        # BGR -> RGB, HWC -> CHW.
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img).astype(np.float32)
        img /= 255.0  # normalize pixel values to [0, 1]
        img = np.expand_dims(img, axis=0)  # add batch dimension -> (1, 3, H, W)
        assert len(img.shape) == 4

        return img0, img

    def centerpoint2xywh(self, x):
        # Convert nx4 boxes from [cx, cy, w, h] (center point) to [x1, y1, w, h]
        # (top-left corner plus size) — the format cv2.dnn.NMSBoxes expects.
        # NOTE: despite the original comment, this does NOT produce xyxy boxes.
        y = np.copy(x)
        y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
        y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
        y[:, 2] = x[:, 2]  # width (unchanged)
        y[:, 3] = x[:, 3]  # height (unchanged)
        return y

    def detect(self, im):
        """Run inference on BGR image *im*.

        Returns (input_image, boxes[x1, y1, w, h], confidences, class_ids)
        for the detections surviving confidence filtering and NMS; box
        coordinates are in letterboxed (img_size) space.
        """
        im0, img = self.preprocess(im)
        # img is (1, 3, H, W); these two values are unused below.
        W, H = img.shape[2:]

        # pred = self.m.run(None, {self.input_name: img})[0]
        pred = self.m.execute([img, ])[0]

        pred = pred.astype(np.float32)
        pred = np.squeeze(pred, axis=0)  # drop batch dim -> (num_anchors, 5 + num_classes)

        # Filter out rows below the objectness/confidence threshold.
        conf_mask = pred[:, 4] > self.threshold
        pred = pred[conf_mask]

        # Extract boxes; truncated to int before the center->corner conversion.
        boxes = pred[:, 0:4].astype("int")
        boxes = self.centerpoint2xywh(boxes)

        # Confidence = best class score * objectness.
        confidences = pred[:, 5:].max(-1) * pred[:, 4]

        # Class id = argmax over the class scores.
        pred_classes = np.argmax(pred[:, 5:], axis=1)

        idxs = cv2.dnn.NMSBoxes(
            boxes, confidences, self.threshold, self.iou_thres)

        return im, boxes[idxs], confidences[idxs], pred_classes[idxs]


def main(opt):
    """Detect objects in opt.source and write the annotated image to result.jpg.

    Raises:
        FileNotFoundError: if opt.source cannot be read as an image
            (cv2.imread silently returns None on failure).
    """
    det = Detector(opt)
    image = cv2.imread(opt.source)
    if image is None:
        raise FileNotFoundError("cannot read image: {}".format(opt.source))

    # Letterboxed inference shape; detect() returns boxes in this space.
    shape = (det.img_size, det.img_size)

    im0, pred_boxes, pred_confes, pred_classes = det.detect(image)
    if len(pred_boxes) > 0:
        for i, _ in enumerate(pred_boxes):
            box = pred_boxes[i]
            left, top, width, height = box[0], box[1], box[2], box[3]
            # [left, top, w, h] -> [x1, y1, x2, y2], then map back to the
            # original image resolution.
            box = (left, top, left + width, top + height)
            box = np.squeeze(
                scale_coords(shape, np.expand_dims(box, axis=0).astype("float"), im0.shape[:2]).round(),
                axis=0).astype("int")
            x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
            cv2.rectangle(image, (x0, y0), (x1, y1), (0, 0, 255), thickness=2)
            cv2.putText(image, '{0}--{1:.2f}'.format(pred_classes[i], pred_confes[i]), (x0, y0 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), thickness=1)
        # Bug fix: write the result once after all boxes are drawn, instead of
        # rewriting the file inside the loop for every detection.
        cv2.imwrite('result.jpg', image)
    # cv2.imshow("detector", image)
    # cv2.waitKey()
    #
    # cv2.destroyAllWindows()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', nargs='+', type=str, default='yolov5s.om', help='om model path(s)')
    parser.add_argument('--source', type=str, default='data/images/bus.jpg',
                        help='file/dir/URL/glob/screen/0(webcam)')
    parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS')
    parser.add_argument('--line-thickness', default=1, type=int, help='bounding box thickness (pixels)')
    parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
    parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
    opt = parser.parse_args()

    # Bug fix: with nargs='+', a user-supplied --weights yields a LIST, but
    # Detector passes opt.weights straight to AclLiteModel, which needs a
    # single path string. Normalize to the first path (the default stays a
    # plain string, so this is backward-compatible).
    if isinstance(opt.weights, (list, tuple)):
        opt.weights = opt.weights[0]

    # Acquire the Ascend ACL runtime before any model/inference work.
    acl_resource = AclLiteResource()
    acl_resource.init()

    main(opt)
